{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Use Partitioners\n",
    "\n",
    "Understand `Partitioner`s interactions with `FederatedDataset`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install -q \"flwr-datasets[vision]\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## What is `Partitioner`?\n",
    "\n",
    "`Partitioner` is an object responsible for dividing a dataset according to a chosen strategy. There are many `Partitioner`s that you can use (see the full list [here](https://flower.ai/docs/datasets/ref-api/flwr_datasets.partitioner.html)) and all of them inherit from the `Partitioner` object which is an abstract class providing basic structure and methods that need to be implemented for any new `Partitioner` to integrate with the rest of `Flower Datasets` code. The creation of different `Partitioner` differs, but the behavior is the same = they produce the same type of objects."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `IidPartitioner` Creation\n",
    "\n",
    "Let's create (instantiate) the most basic partitioner, [IidPartitioner](https://flower.ai/docs/datasets/ref-api/flwr_datasets.partitioner.IidPartitioner.html#flwr_datasets.partitioner.IidPartitioner) and learn how it interacts with `FederatedDataset`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from flwr_datasets.partitioner import IidPartitioner\n",
    "\n",
    "# Set the partitioner to create 10 partitions\n",
    "partitioner = IidPartitioner(num_partitions=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Right now the partitioner does not have access to any data therefore it has nothing to partition. `FederatedDataset` is responsible for assigning data to a `partitioner`(s).\n",
    "\n",
    "What **part** of the data is assigned to partitioner?\n",
    "\n",
    "In centralized (traditional) ML, there exist a strong concept of the splits of the dataset. Typically you can hear about train/valid/test splits. In FL research, if we don't have an already divided datasets (e.g. by `user_id`), we simulate such division using a centralized dataset. The goal of that operation is to simulate an FL scenario where the data is spread across clients. In Flower Datasets you decide what split of the dataset will be partitioned. You can also resplit the datasets such that you use a more non-custom split, or merge the whole train and test split into a single dataset. That's not a part of this tutorial (if you are curious how to do that see [Divider docs](https://flower.ai/docs/datasets/ref-api/flwr_datasets.preprocessor.Divider.html), [Merger docs](https://flower.ai/docs/datasets/ref-api/flwr_datasets.preprocessor.Merger.html) and `preprocessor` parameter docs of [FederatedDataset](https://flower.ai/docs/datasets/ref-api/flwr_datasets.FederatedDataset.html)).\n",
    "\n",
    "Let's see how you specify the split for partitioning."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### How do you specify the split to partition?\n",
    "\n",
    "The specification of the split happens as you specify the `partitioners` argument for `FederatedDataset`. It maps `partition_id: str` to the partitioner that will be used for that split of the data. In the example below we're using the `train` split of the `cifar10` dataset to partition.\n",
    "\n",
 If you're unsure why/how">
    "> If you're unsure why/how we chose the name of the `dataset` and how to customize it, see the [first tutorial](https://flower.ai/docs/datasets/quickstart-tutorial.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['img', 'label'],\n",
       "    num_rows: 5000\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from flwr_datasets import FederatedDataset\n",
    "\n",
    "# Create the federated dataset passing the partitioner\n",
    "fds = FederatedDataset(dataset=\"uoft-cs/cifar10\", partitioners={\"train\": partitioner})\n",
    "\n",
    "# Load the first partition\n",
    "iid_partition = fds.load_partition(partition_id=0)\n",
    "iid_partition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'img': [<PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>,\n",
       "  <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>,\n",
       "  <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>],\n",
       " 'label': [1, 2, 6]}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Let's take a look at the first three samples\n",
    "iid_partition[:3]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Use Different `Partitioners`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Why would you need to use different `Partitioner`s?**\n",
    "\n",
    "There are a few ways that the data partitioning is simulated in the literature. `Flower Datasets` let's you work with many different approaches that have been proposed so far. It enables you to simulate partitions with different properties and different levels of heterogeneity and use those settings to evaluate your Federated Learning algorithms.\n",
    "\n",
    "\n",
    "**How to use different `Partitioner`s?**\n",
    "\n",
    "To use a different `Partitioner` you just need to create a different object (note it has typically different parameters that you need to specify). Then you pass it as before to the `FederatedDataset`.\n",
    "\n",
    "<div style=\"max-width:80%; margin-left: auto; margin-right: auto\">\n",
    "    <img src=\"./_static/tutorial-quickstart/partitioner-flexibility.png\" alt=\"Partitioner flexibility display\">\n",
    "</div>\n",
    "See the only changing part in yellow.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Creating non-IID partitions: Use ``PathologicalPartitioner``\n",
    "\n",
    "Now, we are going to create partitions that have only a subset of labels in each partition by using [PathologicalPartitioner](https://flower.ai/docs/datasets/ref-api/flwr_datasets.partitioner.PathologicalPartitioner.html#flwr_datasets.partitioner.PathologicalPartitioner). In this scenario we have the exact control about the number of unique labels on each partition. The smaller the number is the more heterogeneous the division gets. Let's have a look at how it works with `num_classes_per_partition=2`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['img', 'label'],\n",
       "    num_rows: 2501\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from flwr_datasets.partitioner import PathologicalPartitioner\n",
    "\n",
    "# Set the partitioner to create 10 partitions with 2 classes per partition\n",
    "# Partition using column \"label\" (a column in the huggingface representation of CIFAR-10)\n",
    "pathological_partitioner = PathologicalPartitioner(\n",
    "    num_partitions=10, partition_by=\"label\", num_classes_per_partition=2\n",
    ")\n",
    "\n",
    "# Create the federated dataset passing the partitioner\n",
    "fds = FederatedDataset(\n",
    "    dataset=\"uoft-cs/cifar10\", partitioners={\"train\": pathological_partitioner}\n",
    ")\n",
    "\n",
    "# Load the first partition\n",
    "partition_pathological = fds.load_partition(partition_id=0)\n",
    "partition_pathological"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'img': [<PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>,\n",
       "  <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>,\n",
       "  <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>],\n",
       " 'label': [0, 0, 7]}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Let's take a look at the first three samples\n",
    "partition_pathological[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 7])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "# We can use `np.unique` to get a list of the unique labels that are present\n",
    "# in this data partition. As expected, there are just two labels. This means\n",
    "# that this partition has only images with numbers 0 and 7.\n",
    "np.unique(partition_pathological[\"label\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Creating non-IID partitions: Use ``DirichletPartitioner``\n",
    "\n",
    "With the [DirichletPartitioner](https://flower.ai/docs/datasets/ref-api/flwr_datasets.partitioner.DirichletPartitioner.html#flwr_datasets.partitioner.DirichletPartitioner), the primary tool for controlling heterogeneity is the `alpha` parameter; the smaller the value gets, the more heterogeneous the federated datasets are. Instead of choosing the exact number of classes on each partition, here we sample the probability distribution from the Dirichlet distribution, which tells how the samples associated with each class will be divided."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['img', 'label'],\n",
       "    num_rows: 5433\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from flwr_datasets.partitioner import DirichletPartitioner\n",
    "\n",
    "# Set the partitioner to create 10 partitions with alpha 0.1 (so fairly non-IID)\n",
    "# Partition using column \"label\" (a column in the huggingface representation of CIFAR-10)\n",
    "dirichlet_partitioner = DirichletPartitioner(\n",
    "    num_partitions=10, alpha=0.1, partition_by=\"label\"\n",
    ")\n",
    "\n",
    "# Create the federated dataset passing the partitioner\n",
    "fds = FederatedDataset(\n",
    "    dataset=\"uoft-cs/cifar10\", partitioners={\"train\": dirichlet_partitioner}\n",
    ")\n",
    "\n",
    "# Load the first partition\n",
    "partition_from_dirichlet = fds.load_partition(partition_id=0)\n",
    "partition_from_dirichlet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'img': [<PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>,\n",
       "  <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>,\n",
       "  <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>,\n",
       "  <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>,\n",
       "  <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>],\n",
       " 'label': [4, 4, 0, 1, 4]}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Let's take a look at the first five samples\n",
    "partition_from_dirichlet[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Final remarks\n",
    "Congratulations, you now know how to use different `Partitioner`s with `FederatedDataset` in Flower Datasets."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next Steps\n",
    "This is the second quickstart tutorial from the Flower Datasets series. See next tutorials:\n",
    "\n",
    "* [Visualize Label Distribution](https://flower.ai/docs/datasets/tutorial-visualize-label-distribution.html)\n",
    "\n",
    "Previous tutorials:\n",
    "\n",
    "* [Quickstart Basics](https://flower.ai/docs/datasets/tutorial-quickstart.html)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "flwr",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
